{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to Fine-Tune Multimodal Models or VLMs with Hugging Face TRL\n",
    "\n",
    "Multimodal LLMs are making tremendous progress recently. We now have a diverse ecosystem of powerful open Multimodal models, mostly Vision-Language Models (VLM), including Meta AI's [Llama-3.2-11B-Vision](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct), Mistral AI's [Pixtral-12B](https://huggingface.co/mistralai/Pixtral-12B-2409), Qwen's [Qwen2-VL-7B](https://huggingface.co/Qwen/Qwen2-VL-7B), and Allen AI's [Molmo-7B-D-0924](https://huggingface.co/allenai/Molmo-7B-D-0924).\n",
    "\n",
    "These VLMs can handle a variety of multimodal tasks, including image captioning, visual question answering, and image-text matching without additional training. However, to customize a model for your specific application, you may need to fine-tune it on your data to achieve higher quality results or to create a more efficient model for your use case.\n",
    "\n",
    "This blog post walks you through how to fine-tune open VLMs using Hugging Face [TRL](https://huggingface.co/docs/trl/index), [Transformers](https://huggingface.co/docs/transformers/index) & [datasets](https://huggingface.co/docs/datasets/index) in 2024. We'll cover:\n",
    "\n",
    "1. Define our multimodal use case \n",
    "1. Setup development environment\n",
    "2. Create and prepare the multimodal dataset\n",
    "3. Fine-tune VLM using `trl` and the `SFTTrainer` \n",
    "4. Test and evaluate the VLM\n",
    "\n",
    "_Note: This blog was created to run on consumer size GPUs (24GB), e.g. NVIDIA A10G or RTX 4090/3090, but can be easily adapted to run on bigger GPUs._\n",
    "\n",
    "## 1. Define our multimodal use case \n",
    "\n",
    "When fine-tuning VLMs, it's crucial to clearly define your use case and the multimodal task you want to solve. This will guide your choice of base model and help you create an appropriate dataset for fine-tuning. If you haven't defined your use case yet, you might want to revisit your requirements.\n",
    "\n",
    "It's worth noting that for most use cases fine-tuning might not be the first option. We recommend evaluating pre-trained models or API-based solutions before committing to fine-tuning your own model.\n",
    "\n",
    "As an example, we'll use the following multimodal use case:\n",
    "\n",
    "> We want to fine-tune a model that can generate detailed product descriptions based on product images and basic metadata. This model will be integrated into our e-commerce platform to help sellers create more compelling listings. The goal is to reduce the time it takes to create product descriptions and improve their quality and consistency.\n",
    "\n",
    "Existing models might already be very good for this use case, but you might want to tweak/tune it to your specific needs. This image-to-text generation task is well-suited for fine-tuning VLMs, as it requires understanding visual features and combining them with textual information to produce coherent and relevant descriptions. I created a test dataset for this use case using Gemini 1.5 [philschmid/amazon-product-descriptions-vlm](https://huggingface.co/datasets/philschmid/amazon-product-descriptions-vlm)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Setup development environment\n",
    "\n",
    "Our first step is to install Hugging Face Libraries and Pyroch, including trl, transformers and datasets. If you haven't heard of trl yet, don't worry. It is a library on top of transformers and datasets, which makes it easier to fine-tune, rlhf, align open LLMs. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install Pytorch & other libraries\n",
    "%pip install \"torch==2.4.0\" tensorboard pillow\n",
    "\n",
    "# Install Hugging Face libraries\n",
    "%pip install  --upgrade \\\n",
    "  \"transformers==4.45.1\" \\\n",
    "  \"datasets==3.0.1\" \\\n",
    "  \"accelerate==0.34.2\" \\\n",
    "  \"evaluate==0.4.3\" \\\n",
    "  \"bitsandbytes==0.44.0\" \\\n",
    "  \"trl==0.11.1\" \\\n",
    "  \"peft==0.13.0\" \\\n",
    "  \"qwen-vl-utils\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use the [Hugging Face Hub](https://huggingface.co/models) as a remote model versioning service. This means we will automatically push our model, logs and information to the Hub during training. You must register on the [Hugging Face](https://huggingface.co/join) for this. After you have an account, we will use the `login` util from the `huggingface_hub` package to log into our account and store our token (access key) on the disk.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import login\n",
    "\n",
    "login(\n",
    "  token=\"\", # ADD YOUR TOKEN HERE\n",
    "  add_to_git_credential=True\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Create and prepare the dataset\n",
    "\n",
    "Once you have determined that fine-tuning is the right solution we need to create a dataset to fine-tune our model. We have to prepare the dataset in a format that the model can understand. \n",
    "\n",
    "In our example we will use [philschmid/amazon-product-descriptions-vlm](https://huggingface.co/datasets/philschmid/amazon-product-descriptions-vlm), which contains 1,350 amazon products with title, images and descriptions and metadata. We want to fine-tune our model to generate product descriptions based on the images, title and metadata. Therefore we need to create a prompt including the title, metadata and image and the completion is the description. \n",
    "\n",
    "TRL supports popular instruction and conversation dataset formats. This means we only need to convert our dataset to one of the supported formats and `trl` will take care of the rest. \n",
    "```json\n",
    "\"messages\": [\n",
    "  {\"role\": \"system\", \"content\": [{\"type\":\"text\", \"text\": \"You are a helpful....\"}]},\n",
    "  {\"role\": \"user\", \"content\": [{\n",
    "    \"type\": \"text\", \"text\":  \"How many dogs are in the image?\", \n",
    "    \"type\": \"image\", \"text\": <PIL.Image> \n",
    "    }]},\n",
    "  {\"role\": \"assistant\", \"content\": [{\"type\":\"text\", \"text\": \"There are 3 dogs in the image.\"}]}\n",
    "],\n",
    "```\n",
    "\n",
    "In our example we are going to load our dataset using the Datasets library and apply our frompt and convert it into the the conversational format. \n",
    "\n",
    "Lets start with defining our instruction prompt. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# note the image is not provided in the prompt its included as part of the \"processor\"\n",
    "prompt= \"\"\"Create a Short Product description based on the provided ##PRODUCT NAME## and ##CATEGORY## and image. \n",
    "Only return description. The description should be SEO optimized and for a better mobile search experience.\n",
    "\n",
    "##PRODUCT NAME##: {product_name}\n",
    "##CATEGORY##: {category}\"\"\"\n",
    "\n",
    "system_message = \"You are an expert product description writer for Amazon.\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we can format our dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# Convert dataset to OAI messages       \n",
    "def format_data(sample):\n",
    "    return {\"messages\": [\n",
    "                {\n",
    "                    \"role\": \"system\",\n",
    "                    \"content\": [{\"type\": \"text\", \"text\": system_message}],\n",
    "                },\n",
    "                {\n",
    "                    \"role\": \"user\",\n",
    "                    \"content\": [\n",
    "                        {\n",
    "                            \"type\": \"text\",\n",
    "                            \"text\": prompt.format(product_name=sample[\"Product Name\"], category=sample[\"Category\"]),\n",
    "                        },{\n",
    "                            \"type\": \"image\",\n",
    "                            \"image\": sample[\"image\"],\n",
    "                        }\n",
    "                    ],\n",
    "                },\n",
    "                {\n",
    "                    \"role\": \"assistant\",\n",
    "                    \"content\": [{\"type\": \"text\", \"text\": sample[\"description\"]}],\n",
    "                },\n",
    "            ],\n",
    "        }\n",
    "\n",
    "# Load dataset from the hub\n",
    "dataset_id = \"philschmid/amazon-product-descriptions-vlm\"\n",
    "dataset = load_dataset(\"philschmid/amazon-product-descriptions-vlm\", split=\"train\")\n",
    "\n",
    "# Convert dataset to OAI messages\n",
    "# need to use list comprehension to keep Pil.Image type, .mape convert image to bytes\n",
    "dataset = [format_data(sample) for sample in dataset]\n",
    "\n",
    "print(dataset[345][\"messages\"])"
   ]
  },
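  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the conversational format above, the sketch below verifies that a sample has the expected roles and that every content entry is a typed dictionary. `validate_messages` is a hypothetical helper written for this post, not part of `trl`, and the toy sample uses a placeholder string where the real data holds a `PIL.Image`.\n",
    "\n",
    "```python\n",
    "# Hypothetical helper (not part of trl): checks the message layout used in this post\n",
    "EXPECTED_ROLES = ['system', 'user', 'assistant']\n",
    "\n",
    "def validate_messages(sample):\n",
    "    messages = sample['messages']\n",
    "    roles = [message['role'] for message in messages]\n",
    "    if roles != EXPECTED_ROLES:\n",
    "        return False\n",
    "    for message in messages:\n",
    "        for part in message['content']:\n",
    "            # every content part needs a known type and a payload key of the same name\n",
    "            if part.get('type') not in ('text', 'image'):\n",
    "                return False\n",
    "            if part['type'] not in part:\n",
    "                return False\n",
    "    return True\n",
    "\n",
    "# toy sample; the image entry is a placeholder string instead of a PIL.Image\n",
    "toy_sample = {'messages': [\n",
    "    {'role': 'system', 'content': [{'type': 'text', 'text': 'You are a helpful assistant.'}]},\n",
    "    {'role': 'user', 'content': [\n",
    "        {'type': 'text', 'text': 'How many dogs are in the image?'},\n",
    "        {'type': 'image', 'image': '<PIL.Image placeholder>'},\n",
    "    ]},\n",
    "    {'role': 'assistant', 'content': [{'type': 'text', 'text': 'There are 3 dogs in the image.'}]},\n",
    "]}\n",
    "\n",
    "print(validate_messages(toy_sample))\n",
    "```\n",
    "\n",
    "On the real data you could run `all(validate_messages(s) for s in dataset)` before starting a training run."
   ]
  },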
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Fine-tune VLM using `trl` and the `SFTTrainer` \n",
    "\n",
    "We are now ready to fine-tune our model. We will use the [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) from `trl` to fine-tune our model. The `SFTTrainer` makes it straightfoward to supervise fine-tune open LLMs and VLMs. The `SFTTrainer` is a subclass of the `Trainer` from the `transformers` library and supports all the same features, including logging, evaluation, and checkpointing, but adds additiional quality of life features.\n",
    " \n",
    "We will use the PEFT features in our example. As peft method we will use [QLoRA](https://arxiv.org/abs/2305.14314) a technique to reduce the memory footprint of large language models during finetuning, without sacrificing performance by using quantization. If you want to learn more about QLoRA and how it works, check out [Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes) blog post.\n",
    "\n",
    "\n",
    "_Note: We cannot use Flash Attention as we need to pad our multimodal inputs._\n",
    "\n",
    "We are going to use Qwen 2 VL 7B model, but we can easily swap out the model for another model, including Meta AI's [Llama-3.2-11B-Vision](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct), Mistral AI's [Pixtral-12B](https://huggingface.co/mistralai/Pixtral-12B-2409) or any other LLMs by changing our `model_id` variable. We will use bitsandbytes to quantize our model to 4-bit.\n",
    "\n",
    "_Note: Be aware the bigger the model the more memory it will require. In our example we will use the 7B version, which can be tuned on 24GB GPUs._\n",
    "\n",
    "Correctly, preparing the LLM, Tokenizer and Processor for training VLMs is crucial. The Processor is responsible for including the special tokens and image features in the input. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig\n",
    "\n",
    "# Hugging Face model id\n",
    "model_id = \"Qwen/Qwen2-VL-7B-Instruct\" \n",
    "\n",
    "# BitsAndBytesConfig int-4 config\n",
    "bnb_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type=\"nf4\", bnb_4bit_compute_dtype=torch.bfloat16\n",
    ")\n",
    "\n",
    "# Load model and tokenizer\n",
    "model = AutoModelForVision2Seq.from_pretrained(\n",
    "    model_id,\n",
    "    device_map=\"auto\",\n",
    "    # attn_implementation=\"flash_attention_2\", # not supported for training\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    quantization_config=bnb_config\n",
    ")\n",
    "processor = AutoProcessor.from_pretrained(model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preparation for inference\n",
    "text = processor.apply_chat_template(\n",
    "    dataset[2][\"messages\"], tokenize=False, add_generation_prompt=False\n",
    ")\n",
    "print(text)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `SFTTrainer`  supports a native integration with `peft`, which makes it super easy to efficiently tune LLMs using, e.g. QLoRA. We only need to create our `LoraConfig` and provide it to the trainer. Our `LoraConfig` parameters are defined based on the [qlora paper](https://arxiv.org/pdf/2305.14314.pdf) and sebastian's [blog post](https://magazine.sebastianraschka.com/p/practical-tips-for-finetuning-llms)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import LoraConfig\n",
    "\n",
    "# LoRA config based on QLoRA paper & Sebastian Raschka experiment\n",
    "peft_config = LoraConfig(\n",
    "        lora_alpha=16,\n",
    "        lora_dropout=0.05,\n",
    "        r=8,\n",
    "        bias=\"none\",\n",
    "        target_modules=[\"q_proj\", \"v_proj\"],\n",
    "        task_type=\"CAUSAL_LM\", \n",
    ")"
   ]
  },
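  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feeling for why LoRA is cheap, the sketch below counts the extra weights a rank-`r` adapter adds to a single linear layer and computes the `lora_alpha / r` scaling. The 4096 x 4096 layer size is a hypothetical example chosen for round numbers, not the actual dimensions of Qwen2-VL, and `lora_param_count` is a helper made up for this illustration.\n",
    "\n",
    "```python\n",
    "# Illustrative arithmetic only; the 4096 x 4096 layer size is a hypothetical example\n",
    "def lora_param_count(in_features, out_features, r):\n",
    "    # LoRA adds two low-rank matrices: A (r x in_features) and B (out_features x r)\n",
    "    return r * in_features + out_features * r\n",
    "\n",
    "r, lora_alpha = 8, 16  # same values as in the LoraConfig above\n",
    "full = 4096 * 4096\n",
    "added = lora_param_count(4096, 4096, r)\n",
    "scaling = lora_alpha / r  # factor applied to the low-rank update\n",
    "\n",
    "print(f'full layer weights: {full:,}')\n",
    "print(f'LoRA weights (r={r}): {added:,} ({100 * added / full:.2f}% of the layer)')\n",
    "print(f'scaling factor alpha / r: {scaling}')\n",
    "```\n",
    "\n",
    "Because the count grows linearly with `r`, doubling the rank doubles the number of trainable adapter weights per targeted layer."
   ]
  },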
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before we can start our training we need to define the hyperparameters (`SFTConfig`) we want to use and make sure our inputs are correcty provided to the model. Different to text-only supervised fine-tuning we need to provide the image to the model as well. Therefore we create a custom `DataCollator` which formates the inputs correctly and include the image features. We use the [process_vision_info](https://github.com/QwenLM/Qwen2-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L321) method from a utility package the Qwen2 team provides. If you are using another model, e.g. Llama 3.2 Vision you might have to check if that creates the same processsed image information. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from trl import SFTConfig\n",
    "from transformers import Qwen2VLProcessor\n",
    "from qwen_vl_utils import process_vision_info\n",
    "\n",
    "args = SFTConfig(\n",
    "    output_dir=\"qwen2-7b-instruct-amazon-description\", # directory to save and repository id\n",
    "    num_train_epochs=3,                     # number of training epochs\n",
    "    per_device_train_batch_size=4,          # batch size per device during training\n",
    "    gradient_accumulation_steps=8,          # number of steps before performing a backward/update pass\n",
    "    gradient_checkpointing=True,            # use gradient checkpointing to save memory\n",
    "    optim=\"adamw_torch_fused\",              # use fused adamw optimizer\n",
    "    logging_steps=5,                       # log every 10 steps\n",
    "    save_strategy=\"epoch\",                  # save checkpoint every epoch\n",
    "    learning_rate=2e-4,                     # learning rate, based on QLoRA paper\n",
    "    bf16=True,                              # use bfloat16 precision\n",
    "    tf32=True,                              # use tf32 precision\n",
    "    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper\n",
    "    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper\n",
    "    lr_scheduler_type=\"constant\",           # use constant learning rate scheduler\n",
    "    push_to_hub=True,                       # push model to hub\n",
    "    report_to=\"tensorboard\",                # report metrics to tensorboard\n",
    "    gradient_checkpointing_kwargs = {\"use_reentrant\": False}, # use reentrant checkpointing\n",
    "    dataset_text_field=\"\", # need a dummy field for collator\n",
    "    dataset_kwargs = {\"skip_prepare_dataset\": True} # important for collator\n",
    ")\n",
    "args.remove_unused_columns=False\n",
    "\n",
    "# Create a data collator to encode text and image pairs\n",
    "def collate_fn(examples):\n",
    "    # Get the texts and images, and apply the chat template\n",
    "    texts = [processor.apply_chat_template(example[\"messages\"], tokenize=False) for example in examples]\n",
    "    image_inputs = [process_vision_info(example[\"messages\"])[0] for example in examples]\n",
    "\n",
    "    # Tokenize the texts and process the images\n",
    "    batch = processor(text=texts, images=image_inputs, return_tensors=\"pt\", padding=True)\n",
    "\n",
    "    # The labels are the input_ids, and we mask the padding tokens in the loss computation\n",
    "    labels = batch[\"input_ids\"].clone()\n",
    "    labels[labels == processor.tokenizer.pad_token_id] = -100  #\n",
    "    # Ignore the image token index in the loss computation (model specific)\n",
    "    if isinstance(processor, Qwen2VLProcessor):\n",
    "        image_tokens = [151652,151653,151655]\n",
    "    else: \n",
    "        image_tokens = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]\n",
    "    for image_token_id in image_tokens:\n",
    "        labels[labels == image_token_id] = -100\n",
    "    batch[\"labels\"] = labels\n",
    "\n",
    "    return batch"
   ]
  },
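  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The label masking inside `collate_fn` can be hard to picture, so here is a minimal sketch of the same idea on a toy batch. Plain Python lists stand in for tensors, and the token ids (`0` for padding, `9` for the image placeholder) are made up for illustration. `mask_labels` is a hypothetical helper, not part of `trl` or `transformers`.\n",
    "\n",
    "```python\n",
    "IGNORE_INDEX = -100  # default label value ignored by PyTorch's cross-entropy loss\n",
    "\n",
    "def mask_labels(input_ids, ignore_ids):\n",
    "    # copy the ids and replace every id in ignore_ids with IGNORE_INDEX\n",
    "    return [[IGNORE_INDEX if tok in ignore_ids else tok for tok in row] for row in input_ids]\n",
    "\n",
    "# toy batch: 0 is a made-up padding id, 9 a made-up image placeholder id\n",
    "toy_input_ids = [\n",
    "    [5, 9, 9, 7, 8],\n",
    "    [5, 9, 7, 0, 0],\n",
    "]\n",
    "toy_labels = mask_labels(toy_input_ids, ignore_ids={0, 9})\n",
    "print(toy_labels)\n",
    "```\n",
    "\n",
    "Positions set to `-100` contribute nothing to the loss, so the model is not trained to predict padding or image placeholder tokens."
   ]
  },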
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now have every building block we need to create our `SFTTrainer` to start then training our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from trl import SFTTrainer\n",
    "\n",
    "trainer = SFTTrainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    train_dataset=dataset,\n",
    "    data_collator=collate_fn,\n",
    "    dataset_text_field=\"\", # needs dummy value\n",
    "    peft_config=peft_config,\n",
    "    tokenizer=processor.tokenizer,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Start training our model by calling the `train()` method on our `Trainer` instance. This will start the training loop and train our model for 3 epochs. Since we are using a PEFT method, we will only save the adapted model weights and not the full model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# start training, the model will be automatically saved to the hub and the output directory\n",
    "trainer.train()\n",
    "\n",
    "# save model \n",
    "trainer.save_model(args.output_dir)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training for 3 epochs with a dataset of ~1k samples took 01:31:58 on a `g6.2xlarge`. The instance costs `0.9776$/h` which brings us to a total cost of only `1.4$`."
   ]
  },
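  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The arithmetic behind these numbers can be sketched as follows. The sample count, batch settings, runtime and hourly price are taken from this post; the single GPU is an assumption, and the step count is only an estimate, since the exact number the trainer reports can differ by a few steps depending on how the last partial batch is handled.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# values taken from this post; a single GPU is an assumption\n",
    "num_samples = 1350\n",
    "per_device_train_batch_size = 4\n",
    "gradient_accumulation_steps = 8\n",
    "num_gpus = 1\n",
    "num_train_epochs = 3\n",
    "\n",
    "# samples consumed per optimizer update\n",
    "effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus\n",
    "steps_per_epoch = math.ceil(num_samples / effective_batch_size)\n",
    "total_steps = steps_per_epoch * num_train_epochs\n",
    "\n",
    "# 01:31:58 of training at 0.9776 $/h\n",
    "training_hours = (1 * 3600 + 31 * 60 + 58) / 3600\n",
    "cost = training_hours * 0.9776\n",
    "\n",
    "print(effective_batch_size, steps_per_epoch, total_steps)\n",
    "print(f'{cost:.2f} $')\n",
    "```\n",
    "\n",
    "Changing `per_device_train_batch_size` or `gradient_accumulation_steps` changes the effective batch size, so it is worth recomputing the step count when adapting the setup to a bigger GPU."
   ]
  },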
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# free the memory again\n",
    "del model\n",
    "del trainer\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Test Model and run Inference\n",
    "\n",
    "After the training is done we want to evaluate and test our model. First we will load the base model and let it generate a description for a random Amazon product. Then we will load our Q-LoRA adapted model and let it generate a description for the same product. \n",
    "\n",
    "Finally we can merge the adapter into the base model to make it more efficient and run inference on the same product again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoProcessor, AutoModelForVision2Seq\n",
    "\n",
    "adapter_path = \"./qwen2-7b-instruct-amazon-description\"\n",
    "\n",
    "# Load Model base model\n",
    "model = AutoModelForVision2Seq.from_pretrained(\n",
    "  model_id,\n",
    "  device_map=\"auto\",\n",
    "  torch_dtype=torch.float16\n",
    ")\n",
    "processor = AutoProcessor.from_pretrained(model_id)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "I selected a random product from Amazon and prepared a `generate_description` function to generate a description for the product."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qwen_vl_utils import process_vision_info\n",
    "\n",
    "# sample from amazon.com\n",
    "sample = {\n",
    "  \"product_name\": \"Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur\",\n",
    "  \"catergory\": \"Toys & Games | Toy Figures & Playsets | Action Figures\",\n",
    "  \"image\": \"https://m.media-amazon.com/images/I/81+7Up7IWyL._AC_SY300_SX300_.jpg\"\n",
    "}\n",
    "\n",
    "# prepare message\n",
    "messages = [{\n",
    "        \"role\": \"user\",\n",
    "        \"content\": [\n",
    "            {\n",
    "                \"type\": \"image\",\n",
    "                \"image\": sample[\"image\"],\n",
    "            },\n",
    "            {\"type\": \"text\", \"text\": prompt.format(product_name=sample[\"product_name\"], category=sample[\"catergory\"])},\n",
    "        ],\n",
    "    }\n",
    "]\n",
    "\n",
    "def generate_description(sample, model, processor):\n",
    "    messages = [\n",
    "        {\"role\": \"system\", \"content\": [{\"type\": \"text\", \"text\": system_message}]},\n",
    "        {\"role\": \"user\", \"content\": [\n",
    "            {\"type\": \"image\",\"image\": sample[\"image\"]},\n",
    "            {\"type\": \"text\", \"text\": prompt.format(product_name=sample[\"product_name\"], category=sample[\"catergory\"])},\n",
    "        ]},\n",
    "    ]\n",
    "    # Preparation for inference\n",
    "    text = processor.apply_chat_template(\n",
    "        messages, tokenize=False, add_generation_prompt=True\n",
    "    )\n",
    "    image_inputs, video_inputs = process_vision_info(messages)\n",
    "    inputs = processor(\n",
    "        text=[text],\n",
    "        images=image_inputs,\n",
    "        videos=video_inputs,\n",
    "        padding=True,\n",
    "        return_tensors=\"pt\",\n",
    "    )\n",
    "    inputs = inputs.to(model.device)\n",
    "    # Inference: Generation of the output\n",
    "    generated_ids = model.generate(**inputs, max_new_tokens=256, top_p=1.0, do_sample=True, temperature=0.8)\n",
    "    generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]\n",
    "    output_text = processor.batch_decode(\n",
    "        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False\n",
    "    )\n",
    "    return output_text[0]\n",
    "\n",
    "# let's generate the description\n",
    "base_description = generate_description(sample, model, processor)\n",
    "print(base_description)\n",
    "# you can disable the active adapter if you want to rerun it with\n",
    "# model.disable_adapters()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Awesome it is working! Lets load our adapter and compare if with the base model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.load_adapter(adapter_path) # load the adapter and activate\n",
    "\n",
    "ft_description = generate_description(sample, model, processor)\n",
    "print(ft_description)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lets compare them side by side using a markdown table."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_031db_row0_col0, #T_031db_row0_col1 {\n",
       "  text-align: left;\n",
       "  white-space: pre-wrap;\n",
       "  border: 1px solid black;\n",
       "  padding: 10px;\n",
       "  width: 250px;\n",
       "  overflow-wrap: break-word;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_031db\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_031db_level0_col0\" class=\"col_heading level0 col0\" >Base Generation</th>\n",
       "      <th id=\"T_031db_level0_col1\" class=\"col_heading level0 col1\" >Fine-tuned Generation</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_031db_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n",
       "      <td id=\"T_031db_row0_col0\" class=\"data row0 col0\" >Introducing the Hasbro Marvel Avengers Series Marvel Assemble Titan-Held Iron Man Action Figure, a 30.5 cm tall action figure that is sure to bring the excitement of the Marvel Universe to life! This highly detailed Iron Man figure is perfect for fans of all ages and makes a great addition to any toy collection. With its sleek red and gold armor, this Iron Man figure is ready to take on any challenge. The Titan-Held feature allows for a more realistic and dynamic pose, making it a must-have for any Marvel fan. Whether you're a collector or just looking for a fun toy to play with, this Iron Man action figure is the perfect choice.</td>\n",
       "      <td id=\"T_031db_row0_col1\" class=\"data row0 col1\" >Unleash the power of Iron Man with this 30.5 cm Hasbro Marvel Avengers Titan Hero Action Figure!  This highly detailed Iron Man figure is perfect for collectors and kids alike.  Features a realistic design and articulated joints for dynamic poses.  A must-have for any Marvel fan's collection!</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "def compare_generations(base_gen, ft_gen):\n",
    "    # Create a DataFrame\n",
    "    df = pd.DataFrame({\n",
    "        'Base Generation': [base_gen],\n",
    "        'Fine-tuned Generation': [ft_gen]\n",
    "    })\n",
    "    # Style the DataFrame\n",
    "    styled_df = df.style.set_properties(**{\n",
    "        'text-align': 'left',\n",
    "        'white-space': 'pre-wrap',\n",
    "        'border': '1px solid black',\n",
    "        'padding': '10px',\n",
    "        'width': '250px',  # Set width to 150px\n",
    "        'overflow-wrap': 'break-word'  # Allow words to break and wrap as needed\n",
    "    })\n",
    "    \n",
    "    # Display the styled DataFrame\n",
    "    display(HTML(styled_df.to_html()))\n",
    "  \n",
    "compare_generations(base_description, ft_description)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Nice! Even though we just had ~1k samples we can see that the fine-tuning improve the product description generation. The description is way shorter and more concise, which fits our training data."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optional: Merge LoRA adapter in to the original model\n",
    "\n",
    "When using QLoRA, we only train adapters and not the full model. This means when saving the model during training we only save the adapter weights and not the full model. If you want to save the full model, which makes it easier to use with Text Generation Inference you can merge the adapter weights into the model weights using the `merge_and_unload` method and then save the model with the `save_pretrained` method. This will save a default model, which can be used for inference.\n",
    "\n",
    "_Note: This requires > 30GB CPU Memory._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import PeftModel\n",
    "from transformers import AutoProcessor, AutoModelForVision2Seq\n",
    "\n",
    "adapter_path = \"./qwen2-7b-instruct-amazon-description\"\n",
    "base_model_id = \"Qwen/Qwen2-VL-7B-Instruct\"\n",
    "merged_path = \"merged\"\n",
    "\n",
    "# Load Model base model\n",
    "model = AutoModelForVision2Seq.from_pretrained(model_id, low_cpu_mem_usage=True)\n",
    "\n",
    "# Path to save the merged model\n",
    "\n",
    "# Merge LoRA and base model and save\n",
    "peft_model = PeftModel.from_pretrained(model, adapter_path)\n",
    "merged_model = peft_model.merge_and_unload()\n",
    "merged_model.save_pretrained(merged_path,safe_serialization=True, max_shard_size=\"2GB\")\n",
    "\n",
    "processor = AutoProcessor.from_pretrained(base_model_id)\n",
    "processor.save_pretrained(merged_path)"
   ]
  },
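  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, merging folds the low-rank update into the frozen weight, `W_merged = W + (lora_alpha / r) * B @ A`, so inference no longer needs the separate adapter matrices. The sketch below shows this on tiny hand-written matrices in plain Python. The values and the helpers `matmul` and `merge_lora` are made up for illustration; this is not how `peft` implements `merge_and_unload`.\n",
    "\n",
    "```python\n",
    "# Toy illustration of merging a LoRA update into a weight matrix (not the peft implementation)\n",
    "def matmul(a, b):\n",
    "    # naive matrix product of two nested lists\n",
    "    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))] for i in range(len(a))]\n",
    "\n",
    "def merge_lora(weight, lora_b, lora_a, lora_alpha, r):\n",
    "    # merged = W + (alpha / r) * B @ A\n",
    "    scaling = lora_alpha / r\n",
    "    delta = matmul(lora_b, lora_a)\n",
    "    return [[w + scaling * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(weight, delta)]\n",
    "\n",
    "weight = [[1.0, 0.0], [0.0, 1.0]]   # frozen base weight (2 x 2)\n",
    "lora_b = [[1.0], [2.0]]             # B: 2 x r, with r = 1\n",
    "lora_a = [[3.0, 4.0]]               # A: r x 2\n",
    "merged = merge_lora(weight, lora_b, lora_a, lora_alpha=2, r=1)\n",
    "print(merged)\n",
    "```\n",
    "\n",
    "After merging, a single matrix multiplication with `merged` gives the same result as applying the base weight and the scaled adapter separately."
   ]
  },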
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bonus: Use TRL example script\n",
    "\n",
    "TRL provides a simple example script to fine-tune multimodal models. You can find the script [here](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py). The script can be directly run from the command line and supports all the features of the `SFTTrainer`.\n",
    "```bash\n",
    "# Tested on 8x H100 GPUs\n",
    "accelerate launch\n",
    "    --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \\\n",
    "    examples/scripts/sft_vlm.py \\\n",
    "    --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \\\n",
    "    --model_name_or_path llava-hf/llava-1.5-7b-hf \\\n",
    "    --per_device_train_batch_size 8 \\\n",
    "    --gradient_accumulation_steps 8 \\\n",
    "    --output_dir sft-llava-1.5-7b-hf \\\n",
    "    --bf16 \\\n",
    "    --torch_dtype bfloat16 \\\n",
    "    --gradient_checkpointing\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "2d58e898dde0263bc564c6968b04150abacfd33eed9b19aaa8e45c040360e146"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
